import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as td
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from pbb.models import NNet4l, trainNNet_NTK,CNNet4l, ProbNNet4l, ProbCNNet4l, ProbCNNet9l, CNNet9l, CNNet13l, ProbCNNet13l, ProbCNNet15l, CNNet15l, trainNNet, testNNet, Lambda_var, trainPNNet, computeRiskCertificates, testPosteriorMean, testStochastic, testEnsemble
from pbb.bounds import PBBobj
from pbb import data

def _ntk_grad_norm(net, loader, device, verbose):
    """Return the norm of the gradients reported by ``trainNNet_NTK``.

    Only the first of the three values returned by ``trainNNet_NTK`` is used:
    it is moved to the CPU, converted to a NumPy array and reduced with
    ``np.linalg.norm`` using its default arguments.
    """
    # NOTE(review): the positional ``1`` is forwarded unchanged from the
    # original call sites; its meaning is defined by trainNNet_NTK -- confirm.
    grads, _, _ = trainNNet_NTK(net, 1, loader, device=device, verbose=verbose)
    return np.linalg.norm(grads.cpu().numpy())


def runexp(name_data, objective, prior_type, model, sigma_prior, pmin,
           learning_rate, momentum, learning_rate_prior=0.01,
           momentum_prior=0.95, delta=0.025, layers=9, delta_test=0.01,
           mc_samples=1000, samples_ensemble=100, kl_penalty=1,
           initial_lamb=6.0, train_epochs=100, prior_dist='gaussian',
           verbose=False, device='cuda', prior_epochs=20, dropout_prob=0.2,
           perc_train=1.0, verbose_test=False, perc_prior=0.2,
           batch_size=250, shot_per_class=60, ntk_epochs=(49, 99)):
    """Train a probabilistic 4-layer network and record gradient norms.

    A deterministic ``NNet4l`` is built and used to initialise a
    ``ProbNNet4l``, which is then trained with SGD through ``trainPNNet``
    against a ``PBBobj`` objective.  The gradient norm computed by
    ``trainNNet_NTK`` on the one-batch bound loader is recorded before the
    first training epoch and again after every epoch listed in
    ``ntk_epochs``.

    Args:
        name_data: Dataset identifier passed to ``data.loaddataset``.
        objective: Objective identifier passed to ``PBBobj``.
        prior_type: When equal to ``'rand'``, dropout is disabled
            (``dropout_prob`` is overridden with ``0.0``).
        sigma_prior: Prior scale; converted to ``rho`` with
            ``log(exp(sigma_prior) - 1)``.
        pmin, delta, delta_test, mc_samples, kl_penalty: Forwarded to
            ``PBBobj``.
        learning_rate, momentum: SGD hyper-parameters for the posterior.
        train_epochs: Number of training epochs.
        prior_dist: Forwarded to ``ProbNNet4l``.
        verbose: Forwarded to ``trainNNet_NTK`` and ``trainPNNet``.
        device: Device the networks are moved to.
        dropout_prob: Dropout probability of the initial ``NNet4l``.
        perc_train, perc_prior, batch_size, shot_per_class: Forwarded to
            ``data.loadbatches``; ``shot_per_class`` also determines the
            sample sizes given to ``PBBobj``.
        ntk_epochs: Iterable of zero-based epoch indices after which an extra
            gradient-norm snapshot is taken.  The default reproduces the
            previously hard-coded epochs 49 and 99.
        model, learning_rate_prior, momentum_prior, layers, samples_ensemble,
        initial_lamb, prior_epochs, verbose_test: Accepted for interface
            compatibility; not used by this function.

    Returns:
        list: Gradient norms in the order they were recorded (the initial
        measurement first, then one per reached epoch in ``ntk_epochs``).
        Empty when ``train_epochs`` is 0.
    """
    # Fixed seeds make the initialised prior the same for all bounds.
    torch.manual_seed(7)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    loader_kargs = {'num_workers': 0,
                    'pin_memory': True} if torch.cuda.is_available() else {}

    train, test = data.loaddataset(name_data)
    # Inverse of softplus: rho such that log(1 + exp(rho)) == sigma_prior.
    rho_prior = math.log(math.exp(sigma_prior)-1.0)

    if prior_type == 'rand':
        dropout_prob = 0.0

    # The deterministic net is created before the loaders so the order in
    # which random state is consumed stays the same as before.
    net0 = NNet4l(dropout_prob=dropout_prob, device=device).to(device)
    train_loader, _, _, val_bound_one_batch, _, _ = data.loadbatches(
            train, test, loader_kargs, batch_size, prior=False, perc_train=perc_train, perc_prior=perc_prior,shot_per_class = shot_per_class)

    # NOTE(review): the factor 10 is presumably the number of classes, but it
    # is kept as a literal because that cannot be confirmed here -- verify
    # against data.loadbatches before replacing it with ``classes``.
    posterior_n_size = shot_per_class*10
    bound_n_size = shot_per_class*10
    classes = len(train_loader.dataset.classes)

    net = ProbNNet4l(rho_prior, prior_dist=prior_dist,
                        device=device, init_net=net0).to(device)
    bound = PBBobj(objective, pmin, classes, delta,
                    delta_test, mc_samples, kl_penalty, device, n_posterior = posterior_n_size, n_bound=bound_n_size)

    # No lambda variable is optimised in this experiment.
    optimizer_lambda = None
    lambda_var = None

    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)

    snapshot_epochs = frozenset(ntk_epochs)
    ntk_list = []
    for epoch in trange(train_epochs):
        if epoch == 0:
            # Baseline measurement taken before any training step.
            ntk_list.append(
                _ntk_grad_norm(net, val_bound_one_batch, device, verbose))
            print(ntk_list)
        trainPNNet(net, optimizer, bound, epoch, train_loader, lambda_var, optimizer_lambda, verbose)
        if epoch in snapshot_epochs:
            # Snapshot taken after this epoch's training has finished.
            ntk_list.append(
                _ntk_grad_norm(net, val_bound_one_batch, device, verbose))

    return ntk_list
        
    
    


def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is truthy.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
